# -*- coding: utf-8 -*-
"""
Created on Sun Feb 18 10:07:00 2024

@author: xiehuan
图片对比API服务
process_images_backend
cmd启动命令：uvicorn process_images_backend:app --reload
测试地址  http://127.0.0.1:8000/docs#/default/process_images_endpoint_process_images_post
"""
import logging
import os
import re
import uuid

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, FileResponse
from PIL import Image

# Directory that holds uploaded, intermediate and result images.  It is a
# relative path, so it resolves against the server's working directory; it is
# created at import time if missing.
image_dir = "./processed_images"
os.makedirs(image_dir, exist_ok=True)

app = FastAPI()

# Handle cross-origin requests on the backend: every origin, method and
# header is accepted and credentials are allowed.
# NOTE(review): a wildcard origin list combined with allow_credentials=True is
# very permissive; confirm this is intended before exposing the service
# beyond localhost.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)


def _read_bgr(path: str) -> np.ndarray:
    """Load *path* with cv2.imread (3-channel BGR) or raise ValueError if it fails."""
    image = cv2.imread(path)
    # cv2.imread does not raise on failure; it returns None.
    if image is None:
        raise ValueError(f"could not read image: {path}")
    return image


def crop_images(image_path1: str, image_path2: str, *, max_matches: int = 10,
                ransac_reproj_threshold: float = 5.0):
    """Align the first image onto the second and crop both to the overlap.

    AKAZE features are detected in both grayscale images and matched by
    brute-force Hamming distance.  The ``max_matches`` lowest-distance matches
    feed a RANSAC homography that maps image 1 into image 2's pixel frame.
    Image 1 is warped with it onto a canvas the size of image 2.  Both images
    are then cropped to the axis-aligned bounding box of image 1's warped
    outline, clipped to image 2's bounds.

    Args:
        image_path1: Path of the image that is warped.
        image_path2: Path of the reference image.
        max_matches: Number of best matches used for the homography fit
            (at least 4 are required).
        ransac_reproj_threshold: RANSAC reprojection threshold in pixels.

    Returns:
        ``(cropped_image1, cropped_image2)``: two BGR arrays of equal shape,
        the warped image 1 and image 2 cropped to the same region.

    Raises:
        ValueError: If an image cannot be read, no features / too few matches
            are found, no homography can be estimated, or the overlap is empty.
    """
    image1 = _read_bgr(image_path1)
    image2 = _read_bgr(image_path2)

    gray1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
    gray2 = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)

    akaze = cv2.AKAZE_create()
    keypoints1, descriptors1 = akaze.detectAndCompute(gray1, None)
    keypoints2, descriptors2 = akaze.detectAndCompute(gray2, None)
    # detectAndCompute yields None descriptors when no keypoint was found.
    if descriptors1 is None or descriptors2 is None:
        raise ValueError("no AKAZE features found in one of the images")

    # AKAZE's default descriptors are binary, hence the Hamming matcher.
    matcher = cv2.DescriptorMatcher_create(cv2.DescriptorMatcher_BRUTEFORCE_HAMMING)
    matches = sorted(matcher.match(descriptors1, descriptors2), key=lambda m: m.distance)
    good_matches = matches[:max_matches]
    # findHomography needs at least 4 point correspondences.
    if len(good_matches) < 4:
        raise ValueError(f"need at least 4 feature matches, found {len(good_matches)}")

    src_pts = np.float32([keypoints1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
    dst_pts = np.float32([keypoints2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

    homography, _inlier_mask = cv2.findHomography(
        src_pts, dst_pts, cv2.RANSAC, ransac_reproj_threshold
    )
    if homography is None:
        raise ValueError("could not estimate a homography between the images")

    h2, w2 = image2.shape[:2]
    aligned_image = cv2.warpPerspective(image1, homography, (w2, h2))

    # Outline of image 1 (its own size, not the canvas size) in image 2's frame.
    h1, w1 = image1.shape[:2]
    corners = np.array([[0, 0], [0, h1 - 1], [w1 - 1, h1 - 1], [w1 - 1, 0]], dtype=np.float32)
    transformed = cv2.perspectiveTransform(corners.reshape(1, -1, 2), homography)[0]
    if not np.all(np.isfinite(transformed)):
        raise ValueError("degenerate homography: image outline is not finite")
    xs, ys = transformed[:, 0], transformed[:, 1]

    # Corner coordinates are inclusive pixel positions; convert to half-open
    # slice bounds (hence the +1) and clip to image 2.
    x_min = max(0, int(np.floor(xs.min())))
    y_min = max(0, int(np.floor(ys.min())))
    x_max = min(w2, int(np.ceil(xs.max())) + 1)
    y_max = min(h2, int(np.ceil(ys.max())) + 1)
    if x_min >= x_max or y_min >= y_max:
        raise ValueError("the aligned images do not overlap")

    cropped_image1 = aligned_image[y_min:y_max, x_min:x_max]
    cropped_image2 = image2[y_min:y_max, x_min:x_max]
    return cropped_image1, cropped_image2


def hash_img(img: Image.Image, *, size: int = 10) -> str:
    """Return a row-wise average hash of *img* as a string of '0'/'1' characters.

    The image is resized to ``size`` x ``size`` pixels.  Each pixel is reduced
    to the integer (floored) mean of its R, G and B values.  A pixel maps to
    ``"1"`` when it is at least the mean of its own row, and to ``"0"``
    otherwise; the threshold is per row, not global.  Characters are emitted
    in row-major order.

    Args:
        img: A PIL image.  It is converted to RGB after resizing, so grayscale,
            palette and RGBA images are accepted (alpha is ignored).
        size: Side length of the hash grid.

    Returns:
        A string of exactly ``size * size`` characters.
    """
    small = img.resize((size, size)).convert("RGB")
    rgb = np.asarray(small, dtype=np.int64)  # (size, size, 3)
    # Floor division matches int(channel_sum / 3) for non-negative sums.
    gray = rgb.sum(axis=2) // 3
    row_means = gray.sum(axis=1, keepdims=True) / gray.shape[1]
    bits = gray >= row_means
    return "".join("1" if bit else "0" for bit in bits.ravel())


def similar(img1: Image.Image, img2: Image.Image) -> float:
    """Return the fraction of matching bits between the hashes of two images.

    Both images are hashed with ``hash_img`` using its default grid size.
    The result is ``1 - hamming_distance / bit_count``: 1.0 means identical
    hashes, and 0.0 means every bit differs.
    """
    hash1 = hash_img(img1)
    hash2 = hash_img(img2)
    # Hashes are '0'/'1' strings of equal length, so Hamming distance is the
    # count of positions where the characters differ.
    differing = sum(bit1 != bit2 for bit1, bit2 in zip(hash1, hash2))
    return 1 - differing / len(hash1)


def mark_similar_regions(img1: Image.Image, img2: Image.Image, n: int, *,
                         threshold: float = 0.85, color: tuple = (255, 0, 0),
                         thickness: int = 3) -> np.ndarray:
    """Outline every grid cell in which the two images differ.

    Both images are split into an ``n`` x ``n`` grid.  Each cell is
    ``width // n`` by ``height // n`` pixels, with the size taken from *img1*;
    any leftover strip on the right or bottom edge is not inspected.  Each
    pair of cells is compared with ``similar``, and cells whose similarity is
    below *threshold* are outlined on an RGB copy of *img2*.

    Args:
        img1: First image; its size defines the grid.
        img2: Second image, assumed to have the same size as *img1*.
        n: Number of cells along each axis.
        threshold: Cells with similarity strictly below this value are marked.
        color: Outline colour as an RGB triple (default: red).
        thickness: Outline thickness in pixels.

    Returns:
        An RGB ``uint8`` array of *img2* with the differing cells outlined.

    Raises:
        ValueError: If ``n`` is not positive or the image is too small for an
            ``n`` x ``n`` grid.
    """
    if n <= 0:
        raise ValueError(f"n must be positive, got {n}")
    width, height = img1.size
    cell_w = width // n
    cell_h = height // n
    if cell_w == 0 or cell_h == 0:
        raise ValueError(f"image of size {width}x{height} is too small for a {n}x{n} grid")

    # The array comes from PIL, so channels are RGB (not OpenCV's BGR) and
    # `color` is interpreted as RGB.  np.array copies, so img2 is untouched.
    marked_img = np.array(img2.convert("RGB"))

    for row in range(n):
        top, bottom = row * cell_h, (row + 1) * cell_h
        for col in range(n):
            left, right = col * cell_w, (col + 1) * cell_w
            box = (left, top, right, bottom)
            if similar(img1.crop(box), img2.crop(box)) < threshold:
                cv2.rectangle(marked_img, (left, top), (right, bottom), color, thickness)

    return marked_img


def _upload_path(request_id: str, index: int, upload: UploadFile) -> str:
    """Return a server-chosen path inside ``image_dir`` for one uploaded file.

    The client-supplied filename is never used as a path component, because
    it may contain ``..`` or path separators, or equal the other upload's
    name.  Only a short ASCII-alphanumeric extension is kept from it.
    """
    ext = os.path.splitext(os.path.basename(upload.filename or ""))[1]
    if not re.fullmatch(r"\.[A-Za-z0-9]{1,8}", ext):
        ext = ""
    return os.path.join(image_dir, f"{request_id}_upload{index}{ext}")


def _process_uploads(data1: bytes, file1_path: str, data2: bytes, file2_path: str,
                     request_id: str) -> str:
    """Save two uploads, align and crop them, mark differing cells; return the result path.

    Everything here blocks (disk I/O, OpenCV and PIL work), so the endpoint
    runs it in a worker thread instead of on the event loop.
    """
    with open(file1_path, "wb") as f1:
        f1.write(data1)
    with open(file2_path, "wb") as f2:
        f2.write(data2)

    # Align image 1 onto image 2 and crop both to their common region.
    cropped_image1, cropped_image2 = crop_images(file1_path, file2_path)

    # Persist the cropped pair; the comparison below reads them back with PIL.
    processed_image1_path = os.path.join(image_dir, f"{request_id}_cropped_image1.png")
    processed_image2_path = os.path.join(image_dir, f"{request_id}_cropped_image2.png")
    for path, image in ((processed_image1_path, cropped_image1),
                        (processed_image2_path, cropped_image2)):
        # cv2.imwrite reports some failures only through a False return value.
        if not cv2.imwrite(path, image):
            raise ValueError(f"could not write {path}")

    # Number of grid cells per side for the region comparison.
    n = 32
    with Image.open(processed_image1_path) as img1, Image.open(processed_image2_path) as img2:
        # Mark the cells whose similarity falls below the threshold.
        marked_img = mark_similar_regions(img1, img2, n)

    marked_image_path = os.path.join(image_dir, f"{request_id}_marked_image.png")
    Image.fromarray(marked_img).save(marked_image_path)
    return marked_image_path


@app.post("/process_images")
async def process_images_endpoint(file1: UploadFile = File(...), file2: UploadFile = File(...)):
    """Compare two uploaded images and return image 2 with differing regions outlined.

    Image 1 is aligned onto image 2, both are cropped to their overlap and
    split into a 32x32 grid, and differing cells are outlined on the cropped
    image 2.  All files written for one request share a random prefix inside
    ``image_dir``.  This stops concurrent requests, and two uploads with the
    same filename, from overwriting each other.

    NOTE(review): saved files are never deleted; add a cleanup or retention
    policy if disk usage matters.

    Returns:
        A ``FileResponse`` with the marked PNG on success.  Otherwise a JSON
        body ``{"error": <message>}`` with status 500.
    """
    try:
        request_id = uuid.uuid4().hex
        file1_path = _upload_path(request_id, 1, file1)
        file2_path = _upload_path(request_id, 2, file2)
        data1 = await file1.read()
        data2 = await file2.read()

        # CPU-bound image work must not block the event loop.
        marked_image_path = await run_in_threadpool(
            _process_uploads, data1, file1_path, data2, file2_path, request_id
        )
        return FileResponse(marked_image_path)

    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # report the message to the client.
        logging.getLogger(__name__).exception("process_images failed")
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    # reload=True requires an "module:attribute" import string.  Derive the
    # module part from this file's name so the launcher keeps working if the
    # file is renamed.  For process_images_backend.py it is
    # "process_images_backend:app", exactly as before.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(
        app=f"{module_name}:app",
        host="127.0.0.1",
        port=8000,
        reload=True,
    )

